%% Load and clean data
clear;

% Load the raw catalog data.  After the clear above, every input the rest
% of the script uses (x, dx, xnames, dxnames, y, d, n) must come from here.
load('catalog_data.mat');

% Columns 1-4 of x form a complete set of dummy variables: summed across
% those four columns every row equals one, so together with the intercept
% added later they are perfectly collinear.  Delete the first column of x
% and of dx (and the matching names) to break that dependence.
x(:,1) = [];
dx(:,1) = [];
xnames(1) = [];
dxnames(1) = [];


% Screen out regressors that are essentially empty: one column is
% identically zero and several have a single non-zero entry.  A column of x
% survives only if it has more than 100 non-zero observations; dx and the
% name lists are trimmed with the same mask so they stay aligned with x.
isDense = sum(abs(x) > 0) > 100;
[x, dx] = deal(x(:,isDense), dx(:,isDense)); %#ok<*NODEF>
[xnames, dxnames] = deal(xnames(isDense), dxnames(isDense));
px = numel(xnames);

% Remove columns that are (numerically) perfectly correlated with an
% earlier column.  Looking only below the diagonal, any column whose
% absolute correlation with some lower-indexed column exceeds 1 - 1e-6 is
% dropped; the lower-indexed partner is retained.
absLowerCorr = abs(tril(corr(x), -1));
isRedundant = any(absLowerCorr > 1 - 1e-6, 2);
keepvar = find(~isRedundant);
[x, dx, xnames, dxnames] = deal(x(:,keepvar), dx(:,keepvar), ...
    xnames(keepvar), dxnames(keepvar));

% Some remaining columns would still make "OLS with everything" ill-posed.
% They could stay in at this point, but dropping them makes the OLS
% comparison fairer.  Because the variables are anonymous, redundant ones
% are found purely algebraically.
%
% Two passes are made, first on the treatment rows (d == 1) and then on the
% control rows (d == 0).  Each pass runs a column-pivoted economy QR
% decomposition.  It then keeps only the pivoted columns whose
% |R(k,k)| > 1e-6, i.e. those that would not cause trouble in a regression
% on that subset alone.
% NOTE: keepvar follows the QR pivot order, so the surviving columns are
% reordered.  x, dx, xnames and dxnames are all reordered identically.
for grp = [1 0]
    [~,Rx,ex] = qr(x(d == grp,:),0);
    keepvar = ex(abs(diag(Rx)) > 1e-6);
    x = x(:,keepvar);
    dx = dx(:,keepvar);
    xnames = xnames(keepvar);
    dxnames = dxnames(keepvar);
end
px = numel(xnames);

% Verify that dx is column-aligned with x: after removing '_interaction'
% from the dx names, the name matching xnames{jj} must sit at position jj.
% Later code selects the interaction for a chosen regressor by offsetting
% its column index in Z by px+1, so any misalignment would silently pair
% the wrong columns.  Abort with an explicit error instead of just
% printing summary numbers.
dxnames_noint = strrep(dxnames,'_interaction','');
corr_locate = zeros(px,2);
for jj = 1:px
    matchIdx = find(strcmp(xnames{jj},dxnames_noint));
    % A missing or duplicated match would otherwise surface as an opaque
    % dimension-mismatch error in the assignment below.
    if numel(matchIdx) ~= 1
        error('Expected exactly one interaction column for "%s", found %d.', ...
            xnames{jj}, numel(matchIdx));
    end
    corr_locate(jj,:) = [jj, matchIdx];
end
locateOffset = corr_locate(:,2) - corr_locate(:,1);
% Display the mean and spread of the offsets (both should be zero).
mean(locateOffset)
std(locateOffset)
if any(locateOffset ~= 0)
    error('Interaction columns in dx are not aligned with the main effects in x.');
end

% Full design matrix: intercept, main effects x, indicator d, then dx.
% p is its number of columns.
% d = sparse(d);
% x = sparse(x);
% dx = sparse(dx);
Z = [ones(n,1), x, d, dx];
p = size(Z,2);

% Split the outcome and the main-effect regressors by treatment status.
isTreated = (d == 1);
isControl = (d == 0);
yd1 = y(isTreated,:);
Xd1 = x(isTreated,:);
yd0 = y(isControl,:);
Xd0 = x(isControl,:);

%% OLS with everything
% Least-squares fit of y on the full design Z; standard errors and the
% covariance come from hetero_se (presumably heteroskedasticity-robust --
% confirm in hetero_se), which is given the residuals and inv(Z'*Z).
ZtZ = Z'*Z;
bolsCAT = Z\y;
olsResid = y - Z*bolsCAT;
[solsCAT,VolsCAT] = hetero_se(Z, olsResid, inv(ZtZ));


%% Lasso with plug-in tuning
% Fit feasiblePostLasso separately on control (d == 0) and treatment
% (d == 1) observations using centered y and x, then recover the intercept
% from the group means.  Each fit is warm-started ('beta0') from an OLS fit
% on the (up to) nInit regressors with the largest absolute marginal
% correlation with y.
nInit = 5;   % warm-start support size; capped at the number of regressors

% Control observations
yc0 = yd0 - mean(yd0);
Xc0 = bsxfun(@minus, Xd0, mean(Xd0));   % centering sized by Xd0 itself
czy0 = corr(yd0,Xd0);
czys = sort(-abs(czy0));
% sort puts NaN correlations (zero-variance columns) last, and abs(NaN) >=
% cutoff is false, so they are never selected.  Ties at the cutoff may
% admit more than nInit regressors.
suppy0 = find(abs(czy0) >= -czys(min(nInit, numel(czys))));
by0 = zeros(size(Xd0,2),1);
by0(suppy0) = Xc0(:,suppy0)\yc0;
blas0 = feasiblePostLasso(yc0,Xc0,'MaxIter',1,'beta0',by0);
blas0 = [(mean(yd0)-mean(Xd0)*blas0);blas0];
use0 = blas0 ~= 0;

% Treatment observations (centering is sized by Xd1 rather than sum(d), which
% only equals the number of treated rows when d is strictly 0/1)
yc1 = yd1 - mean(yd1);
Xc1 = bsxfun(@minus, Xd1, mean(Xd1));
czy1 = corr(yd1,Xd1);
czys = sort(-abs(czy1));
suppy1 = find(abs(czy1) >= -czys(min(nInit, numel(czys))));
by1 = zeros(size(Xd1,2),1);
by1(suppy1) = Xc1(:,suppy1)\yc1;
blas1 = feasiblePostLasso(yc1,Xc1,'MaxIter',1,'beta0',by1);
blas1 = [(mean(yd1)-mean(Xd1)*blas1);blas1];
use1 = blas1 ~= 0;

% Post-lasso: OLS on the union of terms selected in either group.  Indices
% 1..px+1 of [intercept; slopes] coincide with Z's intercept and x columns;
% adding px+1 maps each one to its d / dx counterpart in Z.
usepl = union(find(use0),find(use1));
usepl = usepl(:);   % force a column before stacking
usepl = [usepl ; (usepl+px+1)];
Zpl = Z(:,usepl);
bpl = Zpl\y;
[sepl,Vpl] = hetero_se(Zpl,y-Zpl*bpl,inv(Zpl'*Zpl));

% Embed the post-lasso results in full-size arrays aligned with Z's columns
% (zeros for columns that were not selected).
bplCAT = zeros(size(bolsCAT));
bplCAT(usepl) = bpl;
splCAT = zeros(size(bolsCAT));
splCAT(usepl) = sepl;
VplCAT = zeros(size(VolsCAT));
VplCAT(usepl,usepl) = Vpl;


%% "Counterfactual" X
% Counterfactual design aligned with the columns of Z: the intercept and x
% slots are zeroed, the d slot is set to one, and x fills the dx slots.
Xte = [zeros(n,(px+1)) , ones(n,1) , x];

% Arguments shared by every CAT_prof* / fsel_pi* call below.  Their meaning
% is defined by those helpers (not visible here) -- confirm there before
% changing any of them.
profArgA = .3;
profArgB = .7;
seArg = 1e-6;
nTU = ceil(log(n));
allCols = (1:size(Z,2))';   % every column of Z, used for the OLS fits
residOLS = y - Z*bolsCAT;
residPL = y - Z*bplCAT;

%% Average profit differential relative to treating everyone

% OLS estimates
pi1_ols = CAT_prof1(bolsCAT,Xte,profArgA,profArgB);
spi1_ols = CAT_prof1se(bolsCAT,Xte,profArgA,profArgB,seArg,Z,allCols,residOLS);

% Post-lasso estimates
pi1_pl = CAT_prof1(bplCAT,Xte,profArgA,profArgB);
spi1_pl = CAT_prof1se(bplCAT,Xte,profArgA,profArgB,seArg,Z,usepl,residPL);

% TU estimates
[TUpi1_lb,TUpi1_ub] = fsel_pi1(Z,y,Xte,usepl,profArgA,profArgB,nTU);

%% Average profit differential relative to treating no one
% OLS estimates
pi0_ols = CAT_prof0(bolsCAT,Xte,profArgA,profArgB);
spi0_ols = CAT_prof0se(bolsCAT,Xte,profArgA,profArgB,seArg,Z,allCols,residOLS);

% Post-lasso estimates
pi0_pl = CAT_prof0(bplCAT,Xte,profArgA,profArgB);
spi0_pl = CAT_prof0se(bplCAT,Xte,profArgA,profArgB,seArg,Z,usepl,residPL);

% TU estimates
[TUpi0_lb,TUpi0_ub] = fsel_pi0(Z,y,Xte,usepl,profArgA,profArgB,nTU);

%% Save results
% Write the OLS, post-lasso and TU profit-differential results for both
% comparisons (vs. treating everyone, vs. treating no one) to disk.
save('catalog_prof.mat', ...
    'pi1_ols', 'spi1_ols', 'pi1_pl', 'spi1_pl', 'TUpi1_lb', 'TUpi1_ub', ...
    'pi0_ols', 'spi0_ols', 'pi0_pl', 'spi0_pl', 'TUpi0_lb', 'TUpi0_ub');

